# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Generate data for mcore activation UT of inference."""
import numpy as np


def get_init_params(batch_size: int, seq_length: int, hidden_size: int,
                    seed: int = 2025) -> dict[str, np.ndarray]:
    """Generate the random input tensors for the activation test.

    Both tensors are drawn from a normal distribution with mean 0 and
    standard deviation 0.01. ``multiplicand_input`` is drawn first and
    ``multiplier_input`` second. This order is part of the contract: for a
    given seed it determines which values each key receives.

    Note: this reseeds NumPy's *global* legacy RNG via ``np.random.seed``,
    so it also changes the state seen by any later ``np.random.*`` call in
    the same process.

    NOTE(review): the hard-coded outputs in ``get_golden``/``get_gpu_data``
    are 4 x 32. That presumably corresponds to
    ``batch_size * seq_length == 4`` and ``hidden_size == 32`` with the
    default seed; confirm against the test that consumes this module.

    Args:
        batch_size: Multiplied with ``seq_length`` to form the row count.
        seq_length: Multiplied with ``batch_size`` to form the row count.
        hidden_size: Number of columns of each generated tensor.
        seed: Seed passed to ``np.random.seed``. The default 2025 is the
            value previously hard-coded, so existing callers get
            identical inputs.

    Returns:
        A dict with keys ``"multiplicand_input"`` and ``"multiplier_input"``.
        Each value is a float64 array of shape
        ``(batch_size * seq_length, hidden_size)``.
    """
    np.random.seed(seed)
    # Rows from all batch entries are flattened together: (B * S, H).
    tensor_shape = (batch_size * seq_length, hidden_size)
    # Keep this draw order: it fixes which key receives which samples.
    multiplicand = np.random.normal(loc=0, scale=0.01, size=tensor_shape)
    multiplier = np.random.normal(loc=0, scale=0.01, size=tensor_shape)
    return {
        "multiplicand_input": multiplicand,
        "multiplier_input": multiplier,
    }


def get_golden() -> dict[str, np.ndarray]:
    """Return the hard-coded golden (reference) output for the test.

    The reference values are stored as 4 rows of 32 numbers each and are
    returned as a float32 array under the single key ``"output"``.
    """
    # One inner list per output row; each row holds 32 values.
    golden_rows = [
        [
            -4.31316266e-06, -2.41535890e-05, 6.89079752e-05, 1.45538979e-05,
            6.29688247e-06, 5.90012678e-05, 7.37459704e-05, -1.22185361e-06,
            -6.93243855e-05, 1.63602035e-05, -2.48273700e-05, -9.54318239e-06,
            8.44692986e-05, 5.02472005e-07, 3.43676987e-07, -9.27454130e-06,
            3.08749877e-05, -1.56599549e-06, 3.97102622e-06, 1.24504941e-05,
            -1.84118126e-05, 1.72123819e-05, -8.68260213e-06, 1.69471350e-05,
            -2.67131046e-08, -6.37491394e-06, 1.23318944e-06, -1.65603637e-06,
            9.48613815e-05, 3.86531656e-06, 6.14031233e-05, 8.45668455e-06,
        ],
        [
            4.65761531e-07, -1.12804482e-04, 3.93263363e-05, 7.50022627e-06,
            1.36806302e-05, 3.19279970e-05, 3.60762601e-06, 2.29857233e-05,
            4.00640856e-05, -5.20323592e-05, -9.11975221e-05, 4.62121716e-05,
            -4.56244834e-06, 3.63390682e-05, 1.24168519e-05, 3.62076971e-05,
            -2.75885723e-05, 6.72348542e-05, -4.28602016e-06, 1.12097026e-04,
            2.29266589e-05, 1.43112302e-05, 8.33341619e-05, -4.13142443e-05,
            1.63179011e-05, 1.78804237e-06, 1.38807254e-05, 1.86441393e-05,
            1.30439430e-05, 6.94948394e-05, -6.27041436e-06, -4.93406405e-05,
        ],
        [
            -1.05269492e-05, 9.91170873e-06, 5.20948051e-05, 7.95675277e-08,
            -2.87475159e-05, 1.17885365e-05, -7.47259401e-05, -5.62971036e-05,
            3.64419720e-05, -1.44198275e-05, 1.63885506e-05, 4.35137463e-06,
            5.07772056e-05, 1.54796871e-05, 2.49143668e-05, -2.12392933e-05,
            1.52873807e-04, 8.50886590e-06, -4.31083572e-06, 1.10748319e-04,
            -7.86917894e-07, 3.82080361e-05, -1.06906818e-05, 5.38622999e-06,
            -4.23369565e-05, 6.87175198e-05, -1.57163358e-05, 5.43466498e-08,
            9.29316320e-06, 1.44606520e-06, 6.71854359e-05, -1.24226999e-05,
        ],
        [
            -1.39254716e-05, -1.99644751e-06, -3.00226588e-06, 1.44240712e-05,
            1.26534913e-04, 8.84997717e-05, -2.42607148e-05, -5.92464530e-05,
            -1.14083230e-06, -1.39598096e-05, -9.34223954e-06, -5.05931803e-06,
            -1.57249437e-04, 1.77709571e-05, -3.99018245e-05, 2.14316351e-05,
            3.85554195e-05, 3.20778636e-05, 3.10766372e-05, 8.15260023e-07,
            -5.65088776e-05, 2.74936920e-05, 6.60025471e-05, -6.72559599e-06,
            2.84619691e-05, -1.98665603e-05, -1.99393817e-05, 1.13649323e-04,
            -4.10045104e-05, 2.96183098e-05, -2.66621555e-05, 1.01670325e-04,
        ],
    ]
    golden_output = np.asarray(golden_rows, dtype=np.float32)
    return {"output": golden_output}


def get_gpu_data() -> dict[str, np.ndarray]:
    """Return the hard-coded GPU-side output for the test.

    The values are stored as 4 rows of 32 numbers each and are returned as
    a float16 array under the single key ``"output"``. Presumably they were
    captured from a GPU run of the same computation; confirm with the test
    owner.
    """
    # One inner list per output row; each row holds 32 values. The
    # "-0.0000e+00" entry is kept verbatim (signed zero).
    gpu_rows = [
        [
            -4.2915e-06, -2.4199e-05, 6.8665e-05, 1.4484e-05,
            6.3181e-06, 5.9128e-05, 7.3910e-05, -1.2517e-06,
            -6.9141e-05, 1.6332e-05, -2.4915e-05, -9.5367e-06,
            8.4400e-05, 4.7684e-07, 3.5763e-07, -9.2983e-06,
            3.0756e-05, -1.5497e-06, 3.9935e-06, 1.2457e-05,
            -1.8477e-05, 1.7285e-05, -8.7023e-06, 1.6928e-05,
            -0.0000e+00, -6.3181e-06, 1.2517e-06, -1.6689e-06,
            9.5367e-05, 3.8743e-06, 6.1512e-05, 8.4639e-06,
        ],
        [
            4.7684e-07, -1.1253e-04, 3.9339e-05, 7.4506e-06,
            1.3709e-05, 3.1948e-05, 3.5763e-06, 2.3007e-05,
            4.0293e-05, -5.2214e-05, -9.1076e-05, 4.6015e-05,
            -4.5896e-06, 3.6240e-05, 1.2398e-05, 3.6240e-05,
            -2.7418e-05, 6.7234e-05, -4.2915e-06, 1.1206e-04,
            2.2888e-05, 1.4365e-05, 8.3923e-05, -4.1246e-05,
            1.6332e-05, 1.7881e-06, 1.3828e-05, 1.8716e-05,
            1.3113e-05, 6.9618e-05, -6.2585e-06, -4.9353e-05,
        ],
        [
            -1.0550e-05, 9.8944e-06, 5.2214e-05, 5.9605e-08,
            -2.8729e-05, 1.1802e-05, -7.4863e-05, -5.6267e-05,
            3.6478e-05, -1.4484e-05, 1.6332e-05, 4.3511e-06,
            5.0545e-05, 1.5497e-05, 2.4915e-05, -2.1219e-05,
            1.5354e-04, 8.5235e-06, -4.2915e-06, 1.1063e-04,
            -7.7486e-07, 3.8147e-05, -1.0669e-05, 5.3644e-06,
            -4.2439e-05, 6.8188e-05, -1.5736e-05, 5.9605e-08,
            9.3579e-06, 1.4305e-06, 6.6757e-05, -1.2398e-05,
        ],
        [
            -1.4007e-05, -2.0266e-06, -2.9802e-06, 1.4424e-05,
            1.2589e-04, 8.7738e-05, -2.4199e-05, -5.9605e-05,
            -1.1325e-06, -1.3947e-05, -9.3579e-06, -5.0664e-06,
            -1.5736e-04, 1.7881e-05, -3.9816e-05, 2.1458e-05,
            3.8385e-05, 3.2187e-05, 3.1233e-05, 8.3447e-07,
            -5.6744e-05, 2.7657e-05, 6.6280e-05, -6.6757e-06,
            2.8610e-05, -1.9789e-05, -1.9908e-05, 1.1396e-04,
            -4.1008e-05, 2.9564e-05, -2.6703e-05, 1.0204e-04,
        ],
    ]
    gpu_output = np.asarray(gpu_rows, dtype=np.float16)
    return {"output": gpu_output}


# Build both data dicts once, at import time (golden first, then GPU).
GOLDEN_DATA, GPU_DATA = get_golden(), get_gpu_data()
